Enable training strategy for Indexer #3415
Conversation
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This PR enables the selective parameter training strategy (dense warm-up and sparse training stages) for the DeepSeek V3.2 Indexer. It refactors parameter freezing flags and adds tests to verify proper isolation of indexer gradients from the rest of the model.

🔍 General Feedback

- Memory Optimization in Selective Training: The current implementation of optimizer masking computes and stores Adam state parameters for the entire model before zeroing out the updates. I've suggested an explicit mapping with `optax.multi_transform` to avoid allocating massive memory blocks for frozen parameter states, which is critical for 671B model scaling.
- Gradient Isolation in KL Divergence: I left an inline comment pointing out a gradient leak when calculating the KL divergence in `calculate_indexer_loss`. Ensure `jax.lax.stop_gradient` is applied to the target `attention_probs` distribution, so that the main model's queries and keys do not get updated by the indexer's loss.
shuningjin left a comment:
Thanks! Sparse training looks good to me. I left a suggestion on dense warmup, along with minor comments. Will take another look at `trainable_parameters_mask` soon.
shuningjin left a comment:
Looks great! Thanks for the change.
Rohan-Bierneni left a comment:
Thank you for the changes! I have left a small nit.
Description
Enable selective parameter training strategy for DeepSeek V3.2 Indexer (per the paper):

- Add `trainable_parameters_mask` flag, allowing specific parameters to be targeted for training while freezing the rest of the model. Add `TrainableParametersMaskTest` unit tests for validation.
- Add `indexer_sparse_training` flag to indicate the Dense Warm-up stage or the Sparse Training stage for DS v3.2. Add `test_indexer_gradients` unit test to verify proper gradient isolation.
- Rename config keys: `use_sparse_indexer` --> `use_indexer`; `index_head_dim` --> `indexer_head_dim`; `index_n_heads` --> `indexer_n_heads`; and `index_topk` --> `indexer_topk`.

Tests
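The `trainable_parameters_mask` idea described in this PR can be sketched as a path-matching pass over the parameter pytree. This is a hedged illustration under assumed names: the path layout, the `indexer_sparse_training` stage flag usage, and the regex patterns are illustrative, not MaxText's actual config.

```python
import re

import jax
import jax.numpy as jnp

# Illustrative parameter tree with an indexer branch nested in the decoder.
params = {
    "decoder": {
        "attention": {"query": jnp.ones((2, 2))},
        "indexer": {"wq": jnp.ones((2, 2))},
    },
}

def build_mask(params, patterns):
    """Return a pytree of bools: True where the leaf's path matches any
    trainable pattern, False (frozen) otherwise."""
    def match(path, _):
        name = jax.tree_util.keystr(path)
        return any(re.search(p, name) for p in patterns)
    return jax.tree_util.tree_map_with_path(match, params)

indexer_sparse_training = True  # illustrative stage flag
# In this illustrative stage, only indexer params are trainable;
# otherwise every parameter matches and nothing is frozen.
patterns = [r"indexer"] if indexer_sparse_training else [r".*"]
mask = build_mask(params, patterns)

# Zero out gradients for frozen parameters before the optimizer step.
grads = jax.tree_util.tree_map(jnp.ones_like, params)
masked_grads = jax.tree_util.tree_map(
    lambda g, keep: g if keep else jnp.zeros_like(g), grads, mask
)
```

A gradient-isolation test in the spirit of `test_indexer_gradients` would then assert that every non-indexer leaf of `masked_grads` is exactly zero while indexer leaves are untouched.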
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.